class database(object):
    """Static description of one dataset / ROI-extraction profile.

    Instances only store the constructor arguments as attributes; the
    ``Config`` class in this module copies them when its body executes.

    Args:
        dataset_name: identifier for this profile.
        num_classes: number of identity classes.
        num_class_samples: number of samples per class.
        input_shape: shape tuple such as ``(1, 128, 128)``
            (presumably channels, height, width — confirm with the loader).
        train_root: directory holding the training images.
        train_list: list file describing the training samples.
        test_root: directory holding the test images.
        test_list: list file describing the test samples.
    """

    # Attribute names in constructor order; drives __repr__.
    _fields = ('dataset_name', 'num_classes', 'num_class_samples', 'input_shape',
               'train_root', 'train_list', 'test_root', 'test_list')

    def __init__(self, dataset_name, num_classes, num_class_samples, input_shape, train_root, train_list, test_root,
                 test_list):
        self.dataset_name = dataset_name
        self.num_classes = num_classes
        self.num_class_samples = num_class_samples
        self.input_shape = input_shape
        self.train_root = train_root
        self.train_list = train_list
        self.test_root = test_root
        self.test_list = test_list

    def __repr__(self):
        # Constructor-style repr so a printed profile shows every setting.
        args = ', '.join('%s=%r' % (name, getattr(self, name)) for name in self._fields)
        return '%s(%s)' % (type(self).__name__, args)


# IITD profiles: 231 classes, 6 samples per class, (1, 128, 128) inputs.
# The variants differ only in which ROI images / list files they point at.
IITD = database(
    dataset_name='IITD',
    num_classes=231,
    num_class_samples=6,
    input_shape=(1, 128, 128),
    train_root='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/IITD/train_RFN/',
    train_list='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/IITD/train_RFN.txt',
    test_root='/home/yjy/dataset/IITD-ROI-DB-128',
    test_list='/home/yjy/PycharmProjects/PalmROI/data/Datasets/IITD/test_db.txt',
)

IITD_Zhang = database(
    dataset_name='IITD_Zhang',
    num_classes=231,
    num_class_samples=6,
    input_shape=(1, 128, 128),
    train_root='/home/yjy/dataset/IITD-ROI-Zhang/',
    train_list='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/IITD/train_dprime.txt',
    test_root='/home/yjy/dataset/IITD-ROI-Zhang',
    test_list='/home/yjy/PycharmProjects/PalmROI/data/Datasets/IITD/test_zhang.txt',
)

# NOTE(review): train_root is the Zhang ROI directory, whereas the Mahdieh and
# NN profiles train from their own ROI directory — confirm this is intended.
IITD_Ito = database(
    dataset_name='IITD_Ito',
    num_classes=231,
    num_class_samples=6,
    input_shape=(1, 128, 128),
    train_root='/home/yjy/dataset/IITD-ROI-Zhang/',
    train_list='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/IITD/train_dprime.txt',
    test_root='/home/yjy/dataset/IITD-ROI-Ito',
    test_list='/home/yjy/PycharmProjects/PalmROI/data/Datasets/IITD/test_ito.txt',
)

IITD_Mahdieh = database(
    dataset_name='IITD_Mahdieh',
    num_classes=231,
    num_class_samples=6,
    input_shape=(1, 128, 128),
    train_root='/home/yjy/dataset/IITD-ROI-Mahdieh/',
    train_list='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/IITD/train_dprime.txt',
    test_root='/home/yjy/dataset/IITD-ROI-Mahdieh',
    test_list='/home/yjy/PycharmProjects/PalmROI/data/Datasets/IITD/test_Mahdieh.txt',
)

IITD_NN = database(
    dataset_name='IITD_NN',
    num_classes=231,
    num_class_samples=6,
    input_shape=(1, 128, 128),
    train_root='/home/yjy/dataset/IITD-ROI-NN/',
    train_list='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/IITD/train_dprime.txt',
    test_root='/home/yjy/dataset/IITD-ROI-NN',
    test_list='/home/yjy/PycharmProjects/PalmROI/data/Datasets/IITD/test_NN.txt',
)

# Tongji profiles: 301 classes, 10 samples per class, (1, 128, 128) inputs.
# All variants share the same training split; only the test side differs.
Tongji = database(
    dataset_name='Tongji',
    num_classes=301,
    num_class_samples=10,
    input_shape=(1, 128, 128),
    train_root='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/Tongji/train/',
    train_list='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/Tongji/train.txt',
    test_root='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/Tongji/test/',
    test_list='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/Tongji/test.txt',
)

# NOTE(review): Tongji_Zhang and Tongji_NN both reuse the Ito test list
# (test_ito.txt) — confirm the list is valid for their ROI directories.
Tongji_Zhang = database(
    dataset_name='Tongji_Zhang',
    num_classes=301,
    num_class_samples=10,
    input_shape=(1, 128, 128),
    train_root='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/Tongji/train/',
    train_list='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/Tongji/train.txt',
    test_root='/home/yjy/dataset/Tongji-ROI-Zhang_withdefault',
    test_list='/home/yjy/PycharmProjects/PalmROI/data/Datasets/Tongji/test_ito.txt',
)

Tongji_NN = database(
    dataset_name='Tongji_NN',
    num_classes=301,
    num_class_samples=10,
    input_shape=(1, 128, 128),
    train_root='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/Tongji/train/',
    train_list='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/Tongji/train.txt',
    test_root='/home/yjy/dataset/Tongji-ROI-NN',
    test_list='/home/yjy/PycharmProjects/PalmROI/data/Datasets/Tongji/test_ito.txt',
)

Tongji_Ito = database(
    dataset_name='Tongji_Ito',
    num_classes=301,
    num_class_samples=10,
    input_shape=(1, 128, 128),
    train_root='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/Tongji/train/',
    train_list='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/Tongji/train.txt',
    test_root='/home/yjy/dataset/Tongji-ROI-Ito',
    test_list='/home/yjy/PycharmProjects/PalmROI/data/Datasets/Tongji/test_ito.txt',
)

Tongji_Mahdieh = database(
    dataset_name='Tongji_Mahdieh',
    num_classes=301,
    num_class_samples=10,
    input_shape=(1, 128, 128),
    train_root='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/Tongji/train/',
    train_list='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/Tongji/train.txt',
    test_root='/home/yjy/dataset/Tongji-ROI-Mahdieh',
    test_list='/home/yjy/PycharmProjects/PalmROI/data/Datasets/Tongji/test_Mahdieh.txt',
)

# GPDS profile: 51 classes, 10 samples per class, (1, 128, 128) inputs.
GPDS = database(
    dataset_name='GPDS',
    num_classes=51,
    num_class_samples=10,
    input_shape=(1, 128, 128),
    train_root='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/GPDS/train/',
    train_list='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/GPDS/train.txt',
    test_root='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/GPDS/test/',
    test_list='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/GPDS/test.txt',
)

# CASIA profiles: 313 classes, 8 samples per class, (1, 128, 128) inputs.
# All variants share the SDDLM training split; only the test side differs.
# Each profile's dataset_name matches its variable name.
CASIA_Zhang = database(
    dataset_name='CASIA_Zhang',
    num_classes=313,
    num_class_samples=8,
    input_shape=(1, 128, 128),
    train_root='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/CASIA/train_SDDLM/',
    train_list='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/CASIA/train_SDDLM.txt',
    test_root='/home/yjy/dataset/CASIA-ROI-Zhang',
    test_list='/home/yjy/PycharmProjects/PalmROI/data/Datasets/CASIA/test_dprime.txt',
)

CASIA_Ito = database(
    dataset_name='CASIA_Ito',
    num_classes=313,
    num_class_samples=8,
    input_shape=(1, 128, 128),
    train_root='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/CASIA/train_SDDLM/',
    train_list='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/CASIA/train_SDDLM.txt',
    test_root='/home/yjy/dataset/CASIA-ROI-Ito',
    test_list='/home/yjy/PycharmProjects/PalmROI/data/Datasets/CASIA/test_dprime.txt',
)

# NOTE(review): the only CASIA profile using test_bmp.txt instead of
# test_dprime.txt — confirm intended.
CASIA_Mahdieh = database(
    dataset_name='CASIA_Mahdieh',
    num_classes=313,
    num_class_samples=8,
    input_shape=(1, 128, 128),
    train_root='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/CASIA/train_SDDLM/',
    train_list='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/CASIA/train_SDDLM.txt',
    test_root='/home/yjy/dataset/CASIA-ROI-Mahdieh',
    test_list='/home/yjy/PycharmProjects/PalmROI/data/Datasets/CASIA/test_bmp.txt',
)

# Previously named 'CASIA_NN' (copy-paste), which collided with the CASIA_NN
# profile below despite pointing at a different ROI directory.
CASIA_RotNN = database(
    dataset_name='CASIA_RotNN',
    num_classes=313,
    num_class_samples=8,
    input_shape=(1, 128, 128),
    train_root='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/CASIA/train_SDDLM/',
    train_list='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/CASIA/train_SDDLM.txt',
    test_root='/home/yjy/dataset/CASIA-ROI-RotNN',
    test_list='/home/yjy/PycharmProjects/PalmROI/data/Datasets/CASIA/test_dprime.txt',
)

CASIA_NN = database(
    dataset_name='CASIA_NN',
    num_classes=313,
    num_class_samples=8,
    input_shape=(1, 128, 128),
    train_root='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/CASIA/train_SDDLM/',
    train_list='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/CASIA/train_SDDLM.txt',
    test_root='/home/yjy/dataset/CASIA-ROI-NN',
    test_list='/home/yjy/PycharmProjects/PalmROI/data/Datasets/CASIA/test_dprime.txt',
)

CASIA_STN = database(
    dataset_name='CASIA_STN',
    num_classes=313,
    num_class_samples=8,
    input_shape=(1, 128, 128),
    train_root='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/CASIA/train_SDDLM/',
    train_list='/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/CASIA/train_SDDLM.txt',
    test_root='/home/yjy/dataset/CASIA-ROI-STN',
    test_list='/home/yjy/PycharmProjects/PalmROI/data/Datasets/CASIA/test_dprime.txt',
)

# Active dataset profile. Config's dataset fields are read from it when the
# class body below executes, so switching datasets means changing this line.
db = IITD


class Config(object):
    """Training / evaluation settings for the currently selected ``db``.

    The dataset fields are copied from the module-level ``db`` profile at
    class-definition time; all other attributes are fixed values.
    """

    # --- dataset (copied from ``db``) ---
    num_classes = db.num_classes
    numClassSamples = db.num_class_samples
    dataset = db.dataset_name
    input_shape = db.input_shape
    train_root = db.train_root
    train_list = db.train_list
    lfw_root = db.test_root
    lfw_test_list_single = db.test_list
    # NOTE(review): hard-coded to the IITD pair list rather than taken from
    # ``db``, so it stays IITD-specific when ``db`` is switched — confirm.
    lfw_test_list = '/home/yjy/PycharmProjects/arcface-pytorch/data/Datasets/IITD/test_pair.txt'

    # --- model ---
    env = 'default'
    backbone = 'palmnet_yjy'  # alternative backbone: 'resnet_palm18'
    metric = 'arc_margin'
    easy_margin = False
    use_se = True
    loss = 'focal_loss'
    code_length = 512  # alternative value: 1024

    display = False
    finetune = False

    # --- checkpoints ---
    checkpoints_path = 'checkpoints'
    # NOTE(review): also IITD-specific and not derived from ``db`` — confirm.
    test_model_path = 'checkpoints/palmnet_yjy_5000_IITD.pth'
    save_interval = 20

    # --- batching ---
    train_batch_size = 16  # batch size
    test_batch_size = 60

    optimizer = 'sgd'

    # --- runtime ---
    use_gpu = True  # use GPU or not
    gpu_id = '0, 1'
    num_workers = 4  # how many workers for loading data
    print_freq = 100  # print info every N batch

    debug_file = '/tmp/debug'  # if os.path.exists(debug_file): enter ipdb
    result_file = 'result.csv'

    # --- optimisation schedule ---
    max_epoch = 10000
    lr = 5e-2  # initial learning rate
    lr_step = 10
    lr_decay = 0.6  # per original note: lr = lr * lr_decay when val_loss increases (0.95 also tried)
    weight_decay = 5e-4
